{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we are going to demonstrate some basic tasks in graph learning. In general, many of the graph learning problems can fall into the following categories:\n",
    "\n",
    "* Node classification: assign a label to a node.\n",
    "* Link prediction: predict the existence of an edge between two nodes.\n",
    "* Graph classification: assign a label to a graph.\n",
    "\n",
    "Many real-world applications can be formulated as one of these graph problems.\n",
    "* Fraud detection in financial transactions: transactions form a graph, where users are nodes and transactions are edges. In this case, we want to detect malicious users, which is to assign binary labels to users.\n",
    "* Community detection in a social network: a social network is naturally a graph, where nodes are users and edges are interactions between users. We want to predict which community a node belongs to.\n",
    "* Recommendation: users and items form a bipartite graph. They are connected with edges when users purchase items. Given users' purchase history, we want to predict what items a user will purchase in a near future. Thus, recommendation is a link prediction problem.\n",
    "* Drug discovery: a molecule is a graph whose nodes are atoms. We want to predict the property of a molecule. In this case, we want to assign a label to a graph."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get started\n",
    "\n",
    "DGL can be used with different deep learning frameworks. Currently, DGL can be used with Pytorch and MXNet. Here, we show how DGL works with Pytorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import gluon"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we load DGL, we need to set the DGL backend for one of the deep learning frameworks. Because this tutorial develops models in Pytorch, we have to set the DGL backend to Pytorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "from dgl import DGLGraph\n",
    "\n",
    "# Load MXNet as backend\n",
    "dgl.load_backend('mxnet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the rest of the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GNN model\n",
    "\n",
    "Typically, GNN is used to compute meaningful node embeddings. With the embeddings, we can perform many downstream tasks.\n",
    "\n",
    "DGL provides two ways of implementing a GNN model:\n",
    "* using the [nn module](https://doc.dgl.ai/features/nn.html), which contains many commonly used GNN modules.\n",
    "* using the message passing interface to implement a GNN model from scratch.\n",
    "\n",
    "For simplicity, we implement the GNN model in the tutorial with the nn module."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we use [GCN](https://arxiv.org/abs/1609.02907), one of the first inductive GNN models. GCN performs the following computation on every node $v$ in the graph:\n",
    "\n",
    "$$h_{N(v)}^{(l)} \\gets \\Sigma_{u \\in N(v)}{h_u^{(l-1)} + h_v^{(l-1)}}$$\n",
    "$$h_v^{(l)} \\gets \\sigma(W^k \\cdot \\frac{h_{N(v)}^{(l)}}{d_v + 1}),$$\n",
    "\n",
    "where $N(v)$ is the neighborhood of node $v$ and $l$ is the layer Id."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GCN model has multiple layers. In each layer, a vertex accesses its direct neighbors. When we stack $k$ layers in a model, a node $v$ access neighbors within $k$ hops. The output of the GCN model is node embeddings that represent the nodes and all information in the k-hop neighborhood.\n",
    "\n",
    "<img src=\"https://github.com/zheng-da/DGL_devday_tutorial/raw/master/GNN.png\" alt=\"drawing\" width=\"600\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use DGL's `nn` module to build the GCN model. `GraphConv` implements the operations of GCN in a layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.nn.mxnet import conv as dgl_conv\n",
    "\n",
    "class GCNModel(gluon.Block):\n",
    "    def __init__(self,\n",
    "                 in_feats,\n",
    "                 n_hidden,\n",
    "                 out_dim,\n",
    "                 n_layers,\n",
    "                 activation,\n",
    "                 dropout):\n",
    "        super(GCNModel, self).__init__()\n",
    "        self.layers = gluon.nn.Sequential()\n",
    "\n",
    "        # input layer\n",
    "        self.layers.add(dgl_conv.GraphConv(in_feats, n_hidden, activation=activation))\n",
    "        # hidden layers\n",
    "        for i in range(n_layers - 1):\n",
    "            self.layers.add(dgl_conv.GraphConv(n_hidden, n_hidden, activation=activation))\n",
    "        # output layer\n",
    "        self.layers.add(dgl_conv.GraphConv(n_hidden, out_dim, activation=None))\n",
    "        self.dropout = gluon.nn.Dropout(rate=dropout)\n",
    "\n",
    "    def forward(self, g, features):\n",
    "        h = features\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            if i != 0:\n",
    "                h = self.dropout(h)\n",
    "            h = layer(g, h)\n",
    "        return h"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Interested readers can check out our [online tutorials](https://doc.dgl.ai/tutorials/models/index.html) to see how to use DGL's message passing interface to implement GNN models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the dataset for the tutorial\n",
    "\n",
    "DGL has a large collection of built-in datasets. Please see [this doc](https://doc.dgl.ai/api/python/data.html) for more information.\n",
    "\n",
    "In this tutorial, we use a citation network called pubmed for demonstration. A node in the citation network is a paper and an edge represents the citation between two papers. This dataset has 19,717 papers and 88,651 citations. Each paper has a sparse bag-of-words feature vector and a class label.\n",
    "\n",
    "All other graph data, such as node features, are stored as NumPy tensors. When we load the tensors, we convert them to Pytorch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dgl.data import citegrh\n",
    "\n",
    "# load and preprocess the pubmed dataset\n",
    "data = citegrh.load_pubmed()\n",
    "\n",
    "# sparse bag-of-words features of papers\n",
    "features = mx.nd.array(data.features)\n",
    "# the number of input node features\n",
    "in_feats = features.shape[1]\n",
    "# class labels of papers\n",
    "labels = mx.nd.array(data.labels)\n",
    "# the number of unique classes on the nodes.\n",
    "n_classes = data.num_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For small datasets, DGL stores the network structure in a [NetworkX](https://networkx.github.io) object. NetworkX is a very popular Python graph library. It provides comprehensive API for graph manipulation and is very useful for preprocessing small graphs.\n",
    "\n",
    "Here we use NetworkX to remove all self-loops in the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = data.graph\n",
    "g.remove_edges_from(g.selfloop_edges())\n",
    "g.add_edges_from(zip(g.nodes(), g.nodes()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we create a DGLGraph from the grpah dataset and convert it to a read-only DGLGraph, which supports more efficient computation. Currently, DGL sampling API only works on read-only DGLGraphs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = DGLGraph(data.graph)\n",
    "g.readonly()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Node classification in the semi-supervised setting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us perform node classification in a semi-supervised setting. In this setting, we have the entire graph structure and all node features. We only have labels on some of the nodes. We want to predict the labels on other nodes. Even though some of the nodes do not have labels, they connect with nodes with labels. Thus, we train the model with both labeled nodes and unlabeled nodes. Semi-supervised learning can usually improve performance.\n",
    "\n",
    "<img src=\"https://github.com/zheng-da/DGL_devday_tutorial/raw/master/node_classify1.png\" alt=\"drawing\" width=\"200\"/>\n",
    "\n",
    "This dependency graph shows a better view of how labeled and unlabled nodes are used in the training.\n",
    "<img src=\"https://github.com/zheng-da/DGL_devday_tutorial/raw/master/node_classify2.png\" alt=\"drawing\" width=\"800\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we create a 2-layer GCN model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "n_hidden = 16\n",
    "n_layers = 1\n",
    "dropout = 0.5\n",
    "aggregator_type = 'gcn'\n",
    "\n",
    "gconv_model = GCNModel(in_feats,\n",
    "                       n_hidden,\n",
    "                       n_classes,\n",
    "                       n_layers,\n",
    "                       mx.nd.relu,\n",
    "                       dropout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we create the node classification model based on the GCN model. The GCN model takes a DGLGraph object and node features as input and computes node embeddings as output. With node embeddings, we use a cross entropy loss to train the node classification model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NodeClassification(gluon.Block):\n",
    "    def __init__(self, gconv_model, n_hidden, n_classes):\n",
    "        super(NodeClassification, self).__init__()\n",
    "        self.gconv_model = gconv_model\n",
    "        self.loss_fcn = gluon.loss.SoftmaxCELoss()\n",
    "\n",
    "    def forward(self, g, features, train_mask):\n",
    "        logits = self.gconv_model(g, features)\n",
    "        return self.loss_fcn(logits, labels, mx.nd.expand_dims(train_mask, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After defining a model for node classification, we need to define an evaluation function to evaluate the performance of a trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def NCEvaluate(model, g, features, labels, mask):\n",
    "    # compute embeddings with GNN\n",
    "    pred = model.gconv_model(g, features).argmax(axis=1)\n",
    "    accuracy = ((pred == labels) * mask).sum() / mask.sum().asscalar()\n",
    "    return accuracy.asscalar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prepare data for semi-supervised node classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the dataset is split into training set, validation set and testing set.\n",
    "train_mask = mx.nd.array(data.train_mask)\n",
    "val_mask = mx.nd.array(data.val_mask)\n",
    "test_mask = mx.nd.array(data.test_mask)\n",
    "    \n",
    "print(\"\"\"----Data statistics------'\n",
    "      #Classes %d\n",
    "      #Train samples %d\n",
    "      #Val samples %d\n",
    "      #Test samples %d\"\"\" %\n",
    "          (n_classes,\n",
    "           train_mask.sum().asscalar(),\n",
    "           val_mask.sum().asscalar(),\n",
    "           test_mask.sum().asscalar()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After defining the model and evaluation function, we can put everything into the training loop to train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Node classification task\n",
    "model = NodeClassification(gconv_model, n_hidden, n_classes)\n",
    "model.initialize(ctx=mx.cpu())\n",
    "# Training hyperparameters\n",
    "weight_decay = 5e-4\n",
    "n_epochs = 150\n",
    "lr = 3e-2\n",
    "n_train_samples = train_mask.sum().asscalar()\n",
    "\n",
    "# create the Adam optimizer\n",
    "trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': lr, 'wd': weight_decay})\n",
    "\n",
    "dur = []\n",
    "for epoch in range(n_epochs):\n",
    "    # forward\n",
    "    with mx.autograd.record():\n",
    "        loss = model(g, features, train_mask)\n",
    "        loss = loss.sum() / n_train_samples\n",
    "    loss.backward()\n",
    "    trainer.step(batch_size=1)\n",
    "\n",
    "    acc = NCEvaluate(model, g, features, labels, val_mask)\n",
    "    print(\"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f}\"\n",
    "          .format(epoch, loss.asnumpy()[0], acc))\n",
    "\n",
    "print()\n",
    "acc = NCEvaluate(model, g, features, labels, test_mask)\n",
    "print(\"Test Accuracy {:.4f}\".format(acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Link prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The task of link prediction is to predict the existence of an edge between two nodes.\n",
    "<img src=\"https://github.com/zheng-da/DGL_devday_tutorial/raw/master/link_predict1.png\" alt=\"drawing\" width=\"500\"/>\n",
    "\n",
    "In general, we consider that an edge connects two similar nodes. Thus, link prediction is to find pairs of similar nodes.\n",
    "\n",
    "Traditional methods, such as SVD, only takes the graph structure into consideration for link prediction. GNN-based models can take both the graph structure and node features into consideration when estimating if two nodes are connected. Thus, GNN-based models can potentially generate better results.\n",
    "<img src=\"https://github.com/zheng-da/DGL_devday_tutorial/raw/master/link_predict2.png\" alt=\"drawing\" width=\"300\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like the node classification task, we first use GCN as the base model to compute node embeddings in this task as well. Similarly, we can replace GCN with many other GNN models to compute node embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Model hyperparameters\n",
    "n_hidden = in_feats\n",
    "n_layers = 1\n",
    "dropout = 0.5\n",
    "aggregator_type = 'gcn'\n",
    "\n",
    "# create GCN model\n",
    "gconv_model = GCNModel(in_feats,\n",
    "                       n_hidden,\n",
    "                       n_hidden,\n",
    "                       n_layers,\n",
    "                       mx.nd.relu,\n",
    "                       dropout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike node classification, which uses node labels as the training supervision, link prediction simply uses connectivity of nodes as the training signal: nodes connected by edges are similar, while nodes not connected by edges are dissimilar."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train such a model, we need to deploy negative sampling to construct negative samples. In the context of link prediction, a positive sample is an edge that exist in the original graph, while a negative sample is a pair of nodes that don't have an edge between them in the graph. We usually train on each positive sample with multiple negative samples.\n",
    "\n",
    "After having the node embeddings, we compute the similarity scores on positive samples and negative samples. We construct the following loss function on a positive sample and the corresponding negative samples:\n",
    "\n",
    "$$L = -log(\\sigma(z_u^T z_v)) - Q \\cdot E_{v_n \\sim P_n(v)}(log(\\sigma(-z_u^T z_{v_n}))),$$\n",
    "\n",
    "where $Q$ is the number of negative samples.\n",
    "\n",
    "With this loss, training should increase the similarity scores on the positive samples and decrease the similarity scores on negative samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logsigmoid(val):\n",
    "    max_elem = mx.nd.maximum(0., -val)\n",
    "    z = mx.nd.exp(-max_elem) + mx.nd.exp(-val - max_elem)\n",
    "    return -(max_elem + mx.nd.log(z))\n",
    "\n",
    "# NCE loss\n",
    "def NCE_loss(pos_score, neg_score, neg_sample_size):\n",
    "    pos_score = logsigmoid(pos_score)\n",
    "    neg_score = logsigmoid(-neg_score).reshape(-1, neg_sample_size)\n",
    "    return -pos_score - mx.nd.sum(neg_score, axis=1)\n",
    "\n",
    "class LinkPrediction(gluon.Block):\n",
    "    def __init__(self, gconv_model):\n",
    "        super(LinkPrediction, self).__init__()\n",
    "        self.gconv_model = gconv_model\n",
    "\n",
    "    def forward(self, g, features, neg_sample_size):\n",
    "        emb = self.gconv_model(g, features)\n",
    "        pos_g, neg_g = edge_sampler(g, neg_sample_size, return_false_neg=False)\n",
    "        pos_score = score_func(pos_g, emb)\n",
    "        neg_score = score_func(neg_g, emb)\n",
    "        return mx.nd.mean(NCE_loss(pos_score, neg_score, neg_sample_size))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DGL provides an edge sampler `EdgeSampler`, which selects positive edge samples and negative edge samples efficiently. Thus, we can use it to generate positive samples and negative samples for link prediction. `EdgeSampler` generates `neg_sample_size` negative edges by corrupting the head or the tail node of a positive edge with some randomly sampled nodes.\n",
    "\n",
    "<img src=\"https://github.com/zheng-da/DGL_devday_tutorial/raw/master/negative_edges.png\" alt=\"drawing\" width=\"400\"/>\n",
    "\n",
    "`edge_sampler` samples one tenth of the edges in the graph as positive edges. It returns a positive subgraph and a negative subgraph. The positive subgraph contains all positive edges sampled from the graph `g`, while the negative subgraph contains all negative edges constructed by the edge sampler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def edge_sampler(g, neg_sample_size, edges=None, return_false_neg=True):\n",
    "    sampler = dgl.contrib.sampling.EdgeSampler(g, batch_size=int(g.number_of_edges()/10),\n",
    "                                               seed_edges=edges,\n",
    "                                               neg_sample_size=neg_sample_size,\n",
    "                                               negative_mode='tail',\n",
    "                                               shuffle=True,\n",
    "                                               return_false_neg=return_false_neg)\n",
    "    sampler = iter(sampler)\n",
    "    return next(sampler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After having the positive edge subgraph and the negative edge subgraph, we can now compute the similarity on the positive edge samples and negative edge samples.\n",
    "\n",
    "In this tutorial, we use cosine similarity to measure the similarity between two nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_func(g, emb):\n",
    "    src_nid, dst_nid = g.all_edges(order='eid')\n",
    "    # Get the node Ids in the parent graph.\n",
    "    src_nid = g.parent_nid[src_nid]\n",
    "    dst_nid = g.parent_nid[dst_nid]\n",
    "    # Read the node embeddings of the source nodes and destination nodes.\n",
    "    pos_heads = emb[src_nid]\n",
    "    pos_tails = emb[dst_nid]\n",
    "    # cosine similarity\n",
    "    return mx.nd.sum(pos_heads * pos_tails, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like node classification, we define an evaluation function to evaluate the training result. Ideally, the similarity score of a positive edge should be higher than all negative edges.\n",
    "\n",
    "We evaluate the performance of link prediction with MRR (mean reciprocal rank):\n",
    "$$MRR=\\frac{1}{|Q|} \\Sigma_{i=1}^{|Q|} \\frac{1}{rank_i}.$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def LPEvaluate(gconv_model, g, features, eval_eids, neg_sample_size):\n",
    "    emb = gconv_model(g, features)\n",
    "\n",
    "    pos_g, neg_g = edge_sampler(g, neg_sample_size,\n",
    "                                mx.nd.array(eval_eids, dtype=test_eids.dtype),\n",
    "                                return_false_neg=True)\n",
    "    pos_score = score_func(pos_g, emb)\n",
    "    neg_score = score_func(neg_g, emb).reshape(-1, neg_sample_size)\n",
    "    filter_bias = neg_g.edata['false_neg'].reshape(-1, neg_sample_size)\n",
    "\n",
    "    pos_score = logsigmoid(pos_score)\n",
    "    neg_score = logsigmoid(neg_score)\n",
    "    neg_score -= filter_bias.astype(np.float32)\n",
    "    pos_score = mx.nd.expand_dims(pos_score, axis=1)\n",
    "    rankings = mx.nd.sum(neg_score >= pos_score, axis=1) + 1\n",
    "    return np.mean(1.0/rankings.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's get to the actual training.\n",
    "\n",
    "We first split the graph into the training set and the testing set. We random sample 80\\% edges as the training data and 20\\% edges as the testing data. There is no overlap between the training set and the testing set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eids = np.random.permutation(g.number_of_edges())\n",
    "train_eids = eids[:int(len(eids) * 0.8)]\n",
    "valid_eids = eids[int(len(eids) * 0.8):int(len(eids) * 0.9)]\n",
    "test_eids = eids[int(len(eids) * 0.9):]\n",
    "train_g = g.edge_subgraph(train_eids, preserve_nodes=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Model for link prediction\n",
    "model = LinkPrediction(gconv_model)\n",
    "model.initialize()\n",
    "\n",
    "# Training hyperparameters\n",
    "weight_decay = 5e-4\n",
    "n_epochs = 10\n",
    "lr = 1e-4\n",
    "neg_sample_size = 100\n",
    "\n",
    "# use optimizer\n",
    "trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': lr, 'wd': weight_decay})\n",
    "\n",
    "# initialize graph\n",
    "dur = []\n",
    "for epoch in range(n_epochs):\n",
    "    with mx.autograd.record(): \n",
    "        loss = model(train_g, features, neg_sample_size)\n",
    "    loss.backward()\n",
    "    trainer.step(batch_size=1)\n",
    "    acc = LPEvaluate(gconv_model, g, features, valid_eids, neg_sample_size)\n",
    "    print(\"Epoch {:05d} | Loss {:.4f} | MRR {:.4f}\".format(epoch, loss.asscalar(), acc))\n",
    "\n",
    "print()\n",
    "# Let's save the trained node embeddings.\n",
    "acc = LPEvaluate(gconv_model, g, features, test_eids, neg_sample_size)\n",
    "print(\"Test MRR {:.4f}\".format(acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Take home exercise\n",
    "\n",
    "An interested user can try other GNN models to compute node embeddings and use it for node classification. Please check out the [nn module](https://doc.dgl.ai/features/nn.html) in DGL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
